import torch
from PIL import ImageOps
from torchvision.transforms import functional as TF


def interp(t):
    """Cubic fade curve ``3t^2 - 2t^3`` (smoothstep).

    Maps 0 -> 0 and 1 -> 1 with zero slope at both ends. Works on Python
    numbers and tensors alike.
    """
    squared = t**2
    cubed = t**3
    return 3 * squared - 2 * cubed


def perlin(width, height, scale=10, device=None):
    """Generate one octave of 2-D gradient (Perlin) noise.

    A ``(width + 1) x (height + 1)`` lattice of random gradient vectors is
    drawn from a standard normal distribution. Each lattice cell is sampled
    at ``scale x scale`` evenly spaced points in ``[0, 1)``. At each point,
    the dot products with the four corner gradients are blended using the
    ``interp`` fade curve.

    Args:
        width, height: Number of lattice cells along each axis.
        scale: Number of samples per cell along each axis.
        device: Torch device for the result.

    Returns:
        Tensor of shape ``(width * scale, height * scale)``.
    """
    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)

    # Fractional offsets inside a cell; the endpoint 1 belongs to the next cell.
    steps = torch.linspace(0, 1, scale + 1)[:-1]
    fx = steps[:, None].to(device)  # (scale, 1)
    fy = steps[None, :].to(device)  # (1, scale)

    # Blend weight of the lower-index corner on each axis, and its complement.
    near_x = 1 - interp(fx)
    near_y = 1 - interp(fy)
    far_x = 1 - near_x
    far_y = 1 - near_y

    # Each entry: (gradient x, gradient y, offset x, offset y, weight x, weight y).
    # Offsets toward the higher-index corner are negative, so the gradient
    # component is negated and multiplied by the positive distance instead.
    corners = (
        (grad_x[:-1, :-1], grad_y[:-1, :-1], fx, fy, near_x, near_y),
        (-grad_x[1:, :-1], grad_y[1:, :-1], 1 - fx, fy, far_x, near_y),
        (grad_x[:-1, 1:], -grad_y[:-1, 1:], fx, 1 - fy, near_x, far_y),
        (-grad_x[1:, 1:], -grad_y[1:, 1:], 1 - fx, 1 - fy, far_x, far_y),
    )
    # Result is (width, height, scale, scale): cell indices first, sample indices last.
    noise = sum(
        wx * wy * (gx * ox + gy * oy) for gx, gy, ox, oy, wx, wy in corners
    )
    # Interleave cell and sample axes, then flatten to a 2-D map.
    return noise.permute(0, 2, 1, 3).reshape(width * scale, height * scale)


def perlin_ms(octaves, width, height, grayscale, device):
    """Sum several octaves of Perlin noise into one map per channel.

    Every octave doubles the number of lattice cells and halves the
    per-cell sampling factor. As a result, every octave produces a map of
    the same size, ``(width * 2**len(octaves), height * 2**len(octaves))``.
    Each octave is scaled by its entry in ``octaves`` and added on top of a
    constant 0.5 base.

    Args:
        octaves: Sequence of per-octave amplitudes, coarsest octave first.
        width, height: Lattice cells of the coarsest octave along each axis.
        grayscale: If true, build one channel; otherwise build three.
        device: Torch device for the generated tensors.

    Returns:
        Tensor of shape ``(channels * width * 2**len(octaves),
        height * 2**len(octaves))``, with the channels stacked along dim 0.
        With an empty ``octaves`` the result is a constant 0.5 map.
    """
    n_channels = 1 if grayscale else 3
    upsample = 2 ** len(octaves)
    channels = []
    for _ in range(n_channels):
        # Start from a mid-grey tensor (not a bare float) so an empty
        # ``octaves`` still yields a tensor that torch.cat can stack.
        acc = torch.full((width * upsample, height * upsample), 0.5, device=device)
        scale = upsample
        cells_w, cells_h = width, height
        for amplitude in octaves:
            acc += perlin(cells_w, cells_h, scale, device) * amplitude
            # Next octave: twice the cells, half the samples per cell.
            scale //= 2
            cells_w *= 2
            cells_h *= 2
        channels.append(acc)
    return torch.cat(channels)


def create_perlin_noise(octaves, width, height, grayscale, side_y, side_x, device):
    """Render multi-octave Perlin noise as an auto-contrasted RGB PIL image.

    Args:
        octaves, width, height, grayscale, device: Forwarded to ``perlin_ms``.
        side_y, side_x: Output image size in pixels (rows, columns).

    Returns:
        A PIL image of ``side_y`` rows by ``side_x`` columns. The grayscale
        path is converted to RGB; the color path is built from three
        channels. Both are passed through ``ImageOps.autocontrast``.

    NOTE(review): the first dimension of the ``perlin_ms`` output, which is
    derived from ``width``, ends up as image rows here. Callers in this file
    always pass width == height, so this naming swap has no visible effect.
    """
    out = perlin_ms(octaves, width, height, grayscale, device)
    if grayscale:
        # Add a channel dim so TF.resize treats the map as a 1-channel image.
        out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
        out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
    else:
        # perlin_ms stacks the three channels along dim 0; split them back out.
        out = out.reshape(-1, 3, out.shape[0] // 3, out.shape[1])
        out = TF.resize(size=(side_y, side_x), img=out)
        # Drop only the leading batch dim. A bare squeeze() would also
        # collapse a spatial axis when side_y == 1 or side_x == 1.
        out = TF.to_pil_image(out.clamp(0, 1).squeeze(0))

    return ImageOps.autocontrast(out)


def regen_perlin(perlin_mode, side_y, side_x, device, batch_size):
    """Build a batched Perlin-noise init tensor with values in [-1, 1].

    Two noise images are rendered at ``side_y x side_x`` and averaged:
    one with 12 octaves starting from a 1x1 lattice, and one with
    8 octaves starting from a 4x4 lattice.

    Args:
        perlin_mode: ``'color'`` renders both images in color, and
            ``'gray'`` renders both in grayscale. Any other value renders
            the first in color and the second in grayscale.
        side_y, side_x: Image size in pixels (rows, columns).
        device: Torch device for noise generation and for the result.
        batch_size: Number of batch entries in the returned tensor.

    Returns:
        Tensor of shape ``(batch_size, 3, side_y, side_x)``. It is produced
        by ``expand``, so every batch entry shares the same storage; avoid
        in-place writes.
    """
    # Pick the grayscale flag for each of the two noise images.
    if perlin_mode == 'color':
        first_gray, second_gray = False, False
    elif perlin_mode == 'gray':
        first_gray, second_gray = True, True
    else:
        first_gray, second_gray = False, True

    noise_1x1 = create_perlin_noise(
        [1.5**-i * 0.5 for i in range(12)], 1, 1, first_gray, side_y, side_x, device
    )
    noise_4x4 = create_perlin_noise(
        [1.5**-i * 0.5 for i in range(8)], 4, 4, second_gray, side_y, side_x, device
    )

    # Average the two images (each in [0, 1]), then map the result to [-1, 1].
    blended = (TF.to_tensor(noise_1x1) + TF.to_tensor(noise_4x4)) / 2
    blended = blended.to(device).unsqueeze(0) * 2 - 1
    return blended.expand(batch_size, -1, -1, -1)
